
import time
from utils import fixed_smooth, slide_smooth
import torch
import numpy as np
from sklearn.metrics import auc, roc_curve, confusion_matrix, precision_recall_curve
import os
import pickle
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score



def infer_func(model, dataloader, gt, logger, cfg, file, scale, length):
    """Score every item of ``dataloader`` and report ROC-AUC / PR-AUC against ``gt``.

    Args:
        model: Called as ``model(v_input, seq_len, scale)``. It must return
            three values; only the first (the score tensor) is used.
        dataloader: Iterable yielding ``(v_input, name)`` pairs. ``name`` is
            not used.
        gt: Ground-truth labels passed to scikit-learn. Every kept score is
            repeated 16 times before comparison, so ``len(gt)`` must equal
            16 times the total number of kept scores.
        logger: Unused; kept so existing callers keep working.
        cfg: Configuration object. ``cfg.smooth`` selects the smoothing
            (``'fixed'``, ``'slide'``, anything else means none) and
            ``cfg.kappa`` is forwarded to the smoothing helper.
        file: Unused; kept so existing callers keep working.
        scale: Forwarded unchanged to ``model``.
        length: Number of leading scores dropped from each item's output.

    Returns:
        Tuple ``(roc_auc, fpr, tpr)`` as produced by ``roc_curve`` / ``auc``.

    Side effects:
        Switches ``model`` to eval mode, writes ``max.pkl`` to the current
        working directory and prints a one-line summary.
    """
    st = time.time()
    # NOTE(review): nothing appends to this list any more, so ``max.pkl``
    # always receives an empty list. The dump is kept only so that the file
    # is still produced for anything that expects it.
    max_scores = []

    with torch.no_grad():
        model.eval()
        # Start from an empty float CUDA tensor so the concatenated result
        # has the same dtype/device promotion as incremental concatenation.
        score_chunks = [torch.zeros(0).cuda()]

        for v_input, _ in dataloader:
            v_input = v_input.float().cuda(non_blocking=True)
            # Per batch row: how many positions along dim 1 have a non-zero
            # entry along dim 2 (presumably the non-padded snippets).
            seq_len = torch.sum(torch.max(torch.abs(v_input), dim=2)[0] > 0, 1)
            logits, _, _ = model(v_input, seq_len, scale)
            # Average over the leading dimension, then drop the trailing one.
            logits = torch.mean(logits, 0).squeeze(dim=-1)

            # Remember the length before smoothing so the smoothed output
            # can be trimmed back to it.
            seq = len(logits)
            if cfg.smooth == 'fixed':
                logits = fixed_smooth(logits, cfg.kappa)
            elif cfg.smooth == 'slide':
                logits = slide_smooth(logits, cfg.kappa)

            # Discard the first ``length`` scores and anything past ``seq``.
            score_chunks.append(logits[length:seq])

        pred = torch.cat(score_chunks)

        # Use a distinct handle name: the original reused ``file`` here and
        # thereby clobbered the function parameter of the same name.
        with open('max.pkl', 'wb') as dump_fh:
            pickle.dump(max_scores, dump_fh)

        # Each score is repeated 16 times to line up with ``gt``.
        frame_scores = np.repeat(pred.cpu().detach().numpy(), 16)
        frame_labels = list(gt)

        fpr, tpr, _ = roc_curve(frame_labels, frame_scores)
        roc_auc = auc(fpr, tpr)
        pre, rec, _ = precision_recall_curve(frame_labels, frame_scores)
        pr_auc = auc(rec, pre)
        # NOTE(review): the false-alarm rate is NOT computed here; the value
        # printed below is a fixed placeholder.
        far = 0.04

    time_elapsed = time.time() - st
    print('offline AUC:{:.4f} AP:{:.4f} FAR:{:.4f} | Complete in {:.0f}m {:.0f}s\n'.format(
        roc_auc, pr_auc, far, time_elapsed // 60, time_elapsed % 60))
    return roc_auc, fpr, tpr

def plot_scatter(x, y, z, split=140):
    """Draw a two-colour 3D scatter plot and fit a linear SVM on ``(y, z)``.

    The first ``split`` points are drawn in red and labelled class 0; the
    remaining points are drawn in blue and labelled class 1. A linear
    ``SVC`` is fitted on the ``(y, z)`` pairs with those labels and its
    accuracy on the same data is printed.

    Args:
        x, y, z: Equal-length sequences of coordinates.
        split: Index separating the two groups. Defaults to 140, the value
            that was previously hard-coded.

    Side effects:
        Prints the accuracy, saves the figure to ``dot.png`` in the current
        working directory and then shows it.

    Raises:
        ValueError: Raised by ``SVC.fit`` when only one of the two groups
            is non-empty.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(x[:split], y[:split], z[:split],
               c='red', label=f'First {split} points')
    ax.scatter(x[split:], y[split:], z[split:],
               c='blue', label=f'Points after {split}')

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('3D Scatter Plot')
    ax.legend()

    features = np.column_stack((y, z))
    # Derive the label counts from the data instead of assuming exactly
    # 140 + 150 samples, so the label vector always matches the features.
    n_samples = features.shape[0]
    n_first = min(split, n_samples)
    targets = np.concatenate((np.zeros(n_first), np.ones(n_samples - n_first)))

    svm = SVC(kernel='linear')
    svm.fit(features, targets)

    accuracy = accuracy_score(targets, svm.predict(features))
    print("Accuracy:", accuracy)

    # Save BEFORE showing: with an interactive backend ``plt.show()`` blocks
    # and the figure is gone once its window is closed, which would leave
    # ``dot.png`` blank.
    fig.savefig('dot.png')
    plt.show()
